import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import dgl
from collections import defaultdict as ddict
from ordered_set import OrderedSet

class TrainDataset(Dataset):
    """
    Dataset of training queries.

    Every element of ``triples`` is a dict holding a ``'triple'`` entry and a
    ``'label'`` entry. Indexing the dataset converts the triple into a
    LongTensor and the label into a dense float vector of length ``num_ent``
    that is 1.0 at every index listed in the label and 0.0 elsewhere.

    Parameters
    ----------
    triples: The triples used for training the model
    num_ent: Number of entities in the knowledge graph
    lbl_smooth: Label smoothing coefficient (0.0 leaves the labels untouched)

    Returns
    -------
    A training Dataset class instance used by DataLoader
    """

    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
        self.lbl_smooth = lbl_smooth
        # Ids of all entities: 0 .. num_ent - 1.
        self.entities = np.arange(self.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, label_vector)`` for the sample stored at ``idx``."""
        record = self.triples[idx]
        triple = torch.LongTensor(record['triple'])
        targets = np.int32(record['label'])
        trp_label = self.get_label(targets)

        if self.lbl_smooth == 0.0:
            return triple, trp_label

        # Label smoothing: scale the multi-hot vector by (1 - lbl_smooth) and
        # add a constant 1 / num_ent to every entry.
        keep = 1.0 - self.lbl_smooth
        offset = 1.0 / self.num_ent
        return triple, keep * trp_label + offset

    @staticmethod
    def collate_fn(data):
        """Stack a list of ``(triple, label)`` samples along a new batch dimension."""
        triple = torch.stack([trp for trp, _ in data], dim=0)
        trp_label = torch.stack([lbl for _, lbl in data], dim=0)
        return triple, trp_label

    def get_label(self, label):
        """
        Build the dense target vector for one sample.

        Entries whose index appears in ``label`` are 1.0 (the edge exists in
        the graph); every other entry is 0.0.
        """
        y = np.zeros([self.num_ent], dtype=np.float32)
        y[list(label)] = 1.0
        return torch.from_numpy(y)


class TestDataset(Dataset):
    """
    Dataset of evaluation queries.

    Every element of ``triples`` is a dict holding a ``'triple'`` entry and a
    ``'label'`` entry. Indexing the dataset converts the triple into a
    LongTensor and the label into a dense float vector of length ``num_ent``
    that is 1.0 at every index listed in the label and 0.0 elsewhere.

    Parameters
    ----------
    triples: The triples used for evaluating the model
    num_ent: Number of entities in the knowledge graph

    Returns
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """

    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        """Return ``(triple, label_vector)`` for the sample stored at ``idx``."""
        record = self.triples[idx]
        triple = torch.LongTensor(record['triple'])
        targets = np.int32(record['label'])
        return triple, self.get_label(targets)

    @staticmethod
    def collate_fn(data):
        """Stack a list of ``(triple, label)`` samples along a new batch dimension."""
        triple = torch.stack([trp for trp, _ in data], dim=0)
        label = torch.stack([lbl for _, lbl in data], dim=0)
        return triple, label

    def get_label(self, label):
        """
        Build the dense target vector for one sample.

        Entries whose index appears in ``label`` are 1.0 (the edge exists in
        the graph); every other entry is 0.0.
        """
        y = np.zeros([self.num_ent], dtype=np.float32)
        y[list(label)] = 1.0
        return torch.from_numpy(y)


class Data(object):
    """
    Loads the train/test/valid split files of a knowledge graph dataset and
    builds the id mappings, the DGL graph, the label lookups and the
    dataloaders derived from them.
    """

    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Read in the raw triples and convert them into a standard format.

        The split files are read from ``./<dataset>/{train,test,valid}.txt``.
        Each non-blank line holds ``subject<TAB>relation<TAB>object``; all
        three fields are lower-cased before use.

        Parameters
        ----------
        dataset:           The name of the dataset
        lbl_smooth:        Label smoothing
        num_workers:       Number of workers of dataloaders
        batch_size:        Batch size of dataloaders

        Returns
        -------
        self.ent2id:            Entity to unique identifier mapping
        self.rel2id:            Relation to unique identifier mapping; also holds a
                                '<relation>_reverse' entry per relation, offset by num_rel
        self.id2ent:            Inverse mapping of self.ent2id
        self.id2rel:            Inverse mapping of self.rel2id
        self.num_ent:           Number of entities in the knowledge graph
        self.num_rel:           Number of relations in the knowledge graph (reverse relations not counted)

        self.g:                 The dgl graph constructed from the edges in the training set (plus their
                                reversed copies) and all the entities in the knowledge graph
        self.data['train']:     Stores the triples corresponding to training dataset
        self.data['valid']:     Stores the triples corresponding to validation dataset
        self.data['test']:      Stores the triples corresponding to test dataset
        self.sr2o:              (subject, relation) -> list of objects, training edges only
        self.sr2o_all:          (subject, relation) -> list of objects, all splits
        self.triples:           Per-loader samples, each a dict with keys 'triple' and 'label'
        self.data_iter:         The dataloader for different data splits
        """
        self.dataset = dataset
        self.lbl_smooth = lbl_smooth
        self.num_workers = num_workers
        self.batch_size = batch_size

        splits = ['train', 'test', 'valid']

        # Read every split file exactly once; the parsed string triples feed
        # both the vocabulary pass and the id-conversion pass below.
        raw = {split: self._read_split(split) for split in splits}

        # Collect entities and relations in order of first appearance.
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in splits:
            for sub, rel, obj in raw[split]:
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        # Reverse relations take the ids directly after the original ones.
        rel_offset = len(self.rel2id)
        self.rel2id.update({rel + '_reverse': idx + rel_offset for idx, rel in enumerate(rel_set)})

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.num_ent = len(self.ent2id)
        self.num_rel = len(self.rel2id) // 2

        # Convert the string triples to ids. Every split gets an entry, even
        # when its file is empty, so the lookups below cannot raise KeyError.
        self.data = {split: [] for split in splits}
        # Key is (subject, relation); value is the set of objects that follow it.
        sr2o = ddict(set)
        src, dst, rels = [], [], []

        for split in splits:
            for sub, rel, obj in raw[split]:
                sub_id, rel_id, obj_id = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == 'train':
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(sub_id)  # reversed edge
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)

        # Construct the dgl graph: the training edges first, followed by the
        # same edges reversed and labelled with the reverse relation id.
        edge_src = src + dst
        edge_dst = dst + src
        edge_types = rels + [rel_id + self.num_rel for rel_id in rels]
        self.g = dgl.graph((edge_src, edge_dst), num_nodes=self.num_ent)
        # Build the integer tensor directly instead of round-tripping through float32.
        self.g.edata['etype'] = torch.tensor(edge_types, dtype=torch.long)

        # Identify in and out edges: the first half of the edge list holds the
        # original edges, the second half the reversed ones. torch.Tensor
        # produces a floating-point tensor, so the masks hold 1.0 / 0.0.
        half = self.g.num_edges() // 2
        in_edges_mask = [True] * half + [False] * half
        out_edges_mask = [False] * half + [True] * half
        self.g.edata['in_edges_mask'] = torch.Tensor(in_edges_mask)
        self.g.edata['out_edges_mask'] = torch.Tensor(out_edges_mask)

        # Prepare train/valid/test data.
        self.sr2o = {k: list(v) for k, v in sr2o.items()}  # store only the train data

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {k: list(v) for k, v in sr2o.items()}  # store all the data

        self.triples = {
            'train': [],
            'test_tail': [],
            'test_head': [],
            'valid_tail': [],
            'valid_head': [],
        }

        # One training sample per (subject, relation) pair; the object slot
        # is -1 and the label lists every object seen in the training set.
        for (sub, rel), objs in self.sr2o.items():
            self.triples['train'].append({'triple': (sub, rel, -1), 'label': objs})

        # Each evaluation triple yields a tail query and a head query (the
        # latter via the reverse relation); labels come from all splits.
        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.num_rel
                self.triples['{}_{}'.format(split, 'tail')].append(
                    {'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
                self.triples['{}_{}'.format(split, 'head')].append(
                    {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})

        def get_train_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TrainDataset(self.triples[split], self.num_ent, self.lbl_smooth),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TrainDataset.collate_fn,
            )

        def get_test_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TestDataset(self.triples[split], self.num_ent),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TestDataset.collate_fn,
            )

        # train/valid/test dataloaders (all of them are created with shuffle=True)
        self.data_iter = {
            'train':        get_train_data_loader('train', self.batch_size),
            'valid_head':   get_test_data_loader('valid_head', self.batch_size),
            'valid_tail':   get_test_data_loader('valid_tail', self.batch_size),
            'test_head':    get_test_data_loader('test_head', self.batch_size),
            'test_tail':    get_test_data_loader('test_tail', self.batch_size),
        }

    def _read_split(self, split):
        """Return the lower-cased (subject, relation, object) string triples of one split file."""
        parsed = []
        # The context manager closes the file even if a line fails to parse.
        with open('./{}/{}.txt'.format(self.dataset, split)) as handle:
            for line in handle:
                line = line.strip()
                if not line:
                    # Skip blank lines (e.g. a trailing newline at end of file).
                    continue
                sub, rel, obj = map(str.lower, line.split('\t'))
                parsed.append((sub, rel, obj))
        return parsed
        